{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "07_RNN_practicals_colab.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.7"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mlelarge/dataflowr/blob/master/CEA_EDF_INRIA/07_RNN_practicals_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-8ecFD9cgjM9",
        "colab_type": "text"
      },
      "source": [
        "# RNN practicals\n",
        "\n",
        "This jupyter notebook allows you to reproduce and explore the results presented in the [lecture on RNN](https://mlelarge.github.io/dataflowr-slides/X/lesson5.html#1)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pc29CvyHgjNE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "from collections import OrderedDict\n",
        "import scipy.special\n",
        "from scipy.special import binom\n",
        "import matplotlib.pyplot as plt\n",
        "import time"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "utb4MovzgjNe",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def running_mean(x, N):\n",
        "    cumsum = np.cumsum(np.insert(x, 0, 0)) \n",
        "    return (cumsum[N:] - cumsum[:-N]) / float(N)\n",
        "\n",
        "def Catalan(k):\n",
        "    return binom(2*k,k)/(k+1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JIGrOVLSgjNx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "use_gpu = torch.cuda.is_available()\n",
        "def gpu(tensor, gpu=use_gpu):\n",
        "    if gpu:\n",
        "        return tensor.cuda()\n",
        "    else:\n",
        "        return tensor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Ytku22GgjOB",
        "colab_type": "text"
      },
      "source": [
        "# Generation datasets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QfpbFNT1gjOF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "seq_max_len = 20\n",
        "seq_min_len = 4"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uHAs_VxrgjOS",
        "colab_type": "text"
      },
      "source": [
        "## generating positive examples"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DKlEMULfgjOX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# convention: +1 opening parenthesis and -1 closing parenthesis\n",
        "\n",
        "def all_parent(n, a, k=-1):\n",
        "    global res\n",
        "    if k==n-1 and sum(a) == 0:\n",
        "        res.append(a.copy())\n",
        "    elif k==n-1:\n",
        "        pass\n",
        "    else:\n",
        "        k += 1\n",
        "        if sum(a) > 0:\n",
        "            a[k] = 1\n",
        "            all_parent(n,a,k)\n",
        "        \n",
        "            a[k] = -1\n",
        "            all_parent(n,a,k)\n",
        "            a[k] = 0\n",
        "        else:\n",
        "            a[k] = 1\n",
        "            all_parent(n,a,k)\n",
        "            a[k] = 0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dFcYswvXgjOl",
        "colab_type": "text"
      },
      "source": [
        "## generating negative examples"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZXm1-wS9gjOp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def all_parent_mistake(n, a, k=-1):\n",
        "    global res\n",
        "    if k==n-1 and sum(a) >= -1 and sum(a) <= 1 and min(np.cumsum(a))<0:\n",
        "        res.append(a.copy())\n",
        "    elif sum(a) > n-k:\n",
        "        pass\n",
        "    elif k==n-1:\n",
        "        pass\n",
        "    else:\n",
        "        k += 1\n",
        "        if sum(a) >= -1 and k != 0:\n",
        "            a[k] = 1\n",
        "            all_parent_mistake(n,a,k)\n",
        "        \n",
        "            a[k] = -1\n",
        "            all_parent_mistake(n,a,k)\n",
        "            a[k] = 0\n",
        "        else:\n",
        "            a[k] = 1\n",
        "            all_parent_mistake(n,a,k)\n",
        "            a[k] = 0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ePYlXZaYgjO6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# numbering the parentheses\n",
        "# example: seq of len 6\n",
        "# ( ( ( ) ) ) \n",
        "# 0 1 2 4 5 6\n",
        "# we always have ( + ) = seq_len\n",
        "# 'wrong' parentheses are always closing and numbered as:\n",
        "# ) )\n",
        "# 7 8\n",
        "\n",
        "def reading_par(l, n):\n",
        "    res = [0]*len(l)\n",
        "    s = []\n",
        "    n_plus = -1\n",
        "    n_moins = n+1\n",
        "    c = 0\n",
        "    for i in l:\n",
        "        if i == 1:\n",
        "            n_plus += 1\n",
        "            s.append(n_plus)\n",
        "            res[c] = n_plus\n",
        "            c += 1\n",
        "        else:\n",
        "            try:\n",
        "                res[c] = n-s.pop()\n",
        "            except:\n",
        "                res[c] = n_moins\n",
        "                n_moins += 1\n",
        "            c += 1\n",
        "    return res"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dA0MR3f7gjPN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "all_par = OrderedDict()\n",
        "for n in range(seq_min_len,seq_max_len+1,2):\n",
        "    a = [0]*n\n",
        "    res = []\n",
        "    all_parent(n=n,a=a,k=-1)\n",
        "    all_par[n] = [reading_par(k,n) for k in res]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JfMc7FvagjPi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "all_par_mist = OrderedDict()\n",
        "for n in range(seq_min_len,seq_max_len+1,2):\n",
        "    a = [0]*n\n",
        "    res = []\n",
        "    all_parent_mistake(n=n,a=a,k=-1)\n",
        "    all_par_mist[n] = [reading_par(k,n) for k in res]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7PtA8pQIgjP2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "all_par[6]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "26eGcBytgjQG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "all_par_mist[6]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oCo1lHQpgjQT",
        "colab_type": "text"
      },
      "source": [
        "## number of negative examples by length"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mdr-aeg8gjQX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "long_mist = {i:len(l) for (i,l) in zip(all_par_mist.keys(),all_par_mist.values())}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7gWfeHiEgjQf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "long_mist"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zq8iVHxDgjQm",
        "colab_type": "text"
      },
      "source": [
        "## number of positive examples by length"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kxzqq0JCgjQo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Catalan_num = {i:len(l) for (i,l) in zip(all_par.keys(),all_par.values())}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WmRoddtQgjQv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Catalan_num"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nzCECk4EgjQ5",
        "colab_type": "text"
      },
      "source": [
        "Sanity check, see [Catalan numbers](https://en.wikipedia.org/wiki/Catalan_number)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Thfdkm8gjQ7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "[(2*i,Catalan(i)) for i  in range(2,int(seq_max_len/2)+1)]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qri2vYemgjRF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# nombre de suites correctes de longueur entre 4 et 10, alphabet de taille nb_symbol.\n",
        "nb_symbol = 10\n",
        "np.sum([Catalan(i)*int(nb_symbol/2)**i for i in range(2,int(seq_max_len/2)+1)])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LVc13MYfgjRR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import random\n",
        "import torch\n",
        "\n",
        "class SequenceGenerator():\n",
        "    def __init__(self, nb_symbol = 10, seq_min_len = 4, seq_max_len = 10):\n",
        "        self.nb_symbol = nb_symbol\n",
        "        self.seq_min_len = seq_min_len\n",
        "        self.seq_max_len = seq_max_len\n",
        "        self.population = [i for i in range(int(nb_symbol/2))]\n",
        "                \n",
        "    def generate_pattern(self):\n",
        "        len_r = random.randint(self.seq_min_len/2,self.seq_max_len/2)\n",
        "        pattern = random.choices(self.population,k=len_r)\n",
        "        return pattern + pattern[::-1]\n",
        "    \n",
        "    def generate_pattern_parenthesis(self, len_r = None):\n",
        "        if len_r == None:\n",
        "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
        "        res = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
        "        return res\n",
        "    \n",
        "    def generate_parenthesis_false(self):\n",
        "        len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,long_mist[len_r]-1)\n",
        "        res = [pattern[i] if i <= len_r/2 \n",
        "               else  self.nb_symbol-1-pattern[len_r-i] if i<= len_r \n",
        "               else self.nb_symbol-1-pattern[i-len_r] for i in all_par_mist[len_r][ind_r]]\n",
        "        return res\n",
        "    \n",
        "    def generate_hard_parenthesis(self, len_r = None):\n",
        "        if len_r == None:\n",
        "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
        "        res = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
        "        \n",
        "        if len_r == None:\n",
        "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
        "        res2 = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
        "        return res + res2\n",
        "    \n",
        "    def generate_hard_nonparenthesis(self, len_r = None):\n",
        "        if len_r == None:\n",
        "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,long_mist[len_r]-1)\n",
        "        res = [pattern[i] if i <= len_r/2 \n",
        "               else  self.nb_symbol-1-pattern[len_r-i] if i<= len_r \n",
        "               else self.nb_symbol-1-pattern[i-len_r] for i in all_par_mist[len_r][ind_r]]\n",
        "        \n",
        "        if len_r == None:\n",
        "            len_r = int(2*random.randint(self.seq_min_len/2,self.seq_max_len/2))\n",
        "        pattern = np.random.choice(self.population,size=int(len_r/2),replace=True)\n",
        "        ind_r = random.randint(0,Catalan_num[len_r]-1)\n",
        "        res2 = [pattern[i] if i <= len_r/2 else self.nb_symbol-1-pattern[len_r-i] for i in all_par[len_r][ind_r]]\n",
        "        return  res +[self.nb_symbol-1-pattern[0]]+ res2\n",
        "        \n",
        "    def generate_false(self):\n",
        "        popu = [i for i in range(nb_symbol)]\n",
        "        len = random.randint(self.seq_min_len/2,self.seq_max_len/2)\n",
        "        return random.choices(popu,k=len) + random.choices(popu,k=len)\n",
        "    \n",
        "    def generate_label(self, x):\n",
        "        l = int(len(x)/2)\n",
        "        return 1 if x[:l] == x[:l-1:-1] else 0\n",
        "    \n",
        "    def generate_label_parenthesis(self, x):\n",
        "        s = []\n",
        "        label = 1\n",
        "        lenx = len(x)\n",
        "        for i in x:\n",
        "            if s == [] and i < self.nb_symbol/2:\n",
        "                s.append(i)\n",
        "            elif s == [] and i >= self.nb_symbol/2:\n",
        "                label = 0\n",
        "                break\n",
        "            elif i == self.nb_symbol-1-s[-1]:\n",
        "                s.pop()\n",
        "            else:\n",
        "                s.append(i)\n",
        "        if s != []:\n",
        "            label = 0\n",
        "        return label\n",
        "    \n",
        "    def one_hot(self,seq):\n",
        "        one_hot_seq = []\n",
        "        for s in seq:\n",
        "            one_hot = [0 for _ in range(self.nb_symbol)]\n",
        "            one_hot[s] = 1\n",
        "            one_hot_seq.append(one_hot)\n",
        "        return one_hot_seq\n",
        "    \n",
        "    def generate_input(self, len_r = None, true_parent = False, hard_false = True):\n",
        "        if true_parent:\n",
        "            seq = self.generate_pattern_parenthesis(len_r)\n",
        "        elif bool(random.getrandbits(1)):\n",
        "            seq = self.generate_pattern_parenthesis(len_r)\n",
        "        else:\n",
        "            if hard_false:\n",
        "                seq = self.generate_parenthesis_false()\n",
        "            else:\n",
        "                seq = self.generate_false()\n",
        "        return gpu(torch.from_numpy(np.array(self.one_hot(seq))).type(torch.FloatTensor)), gpu(torch.from_numpy(np.array([self.generate_label_parenthesis(seq)])))\n",
        "\n",
        "    def generate_input_hard(self,true_parent = False):\n",
        "        if true_parent:\n",
        "            seq = self.generate_hard_parenthesis(self.seq_max_len)\n",
        "        elif bool(random.getrandbits(1)):\n",
        "            seq = self.generate_hard_parenthesis(self.seq_max_len)\n",
        "        else:\n",
        "            seq = self.generate_hard_nonparenthesis(self.seq_max_len)\n",
        "            \n",
        "        return gpu(torch.from_numpy(np.array(self.one_hot(seq))).type(torch.FloatTensor)), gpu(torch.from_numpy(np.array([self.generate_label_parenthesis(seq)])))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ivBE7UIYgjRc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_symbol = 10\n",
        "generator = SequenceGenerator(nb_symbol = nb_symbol, seq_min_len = seq_min_len, seq_max_len = seq_max_len)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dj5YgzvXgjRq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "generator.generate_pattern_parenthesis()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lbKlAmE_gjRz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "x = generator.generate_parenthesis_false()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2kdSGWGCgjR7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "generator.generate_label_parenthesis(x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4hL2UnY6gjSC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "generator.generate_input()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hkOTjH01gjSN",
        "colab_type": "text"
      },
      "source": [
        "# First RNN"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bLIQlP1QgjSP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "class RecNet(nn.Module):\n",
        "    def __init__(self, dim_input=10, dim_recurrent=50, dim_output=2):\n",
        "        super(RecNet, self).__init__()\n",
        "        self.fc_x2h = nn.Linear(dim_input, dim_recurrent)\n",
        "        self.fc_h2h = nn.Linear(dim_recurrent, dim_recurrent, bias = False)\n",
        "        self.fc_h2y = nn.Linear(dim_recurrent, dim_output)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        h = x.new_zeros(1, self.fc_h2y.weight.size(1))\n",
        "        for t in range(x.size(0)):\n",
        "            h = torch.relu(self.fc_x2h(x[t,:]) + self.fc_h2h(h))    \n",
        "        return self.fc_h2y(h)\n",
        "    \n",
        "RNN = gpu(RecNet(dim_input = nb_symbol))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AcbhSU5ggjSW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "cross_entropy = nn.CrossEntropyLoss()\n",
        "learning_rate = 1e-3\n",
        "optimizer = torch.optim.Adam(RNN.parameters(),lr=learning_rate)\n",
        "nb_train = 40000\n",
        "loss_t = []\n",
        "corrects =[]\n",
        "labels = []\n",
        "start = time.time()\n",
        "for k in range(nb_train):\n",
        "    x,l = generator.generate_input(hard_false = False)\n",
        "    y = RNN(x)\n",
        "    loss = cross_entropy(y,l)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    corrects.append(preds.item() == l.data.item())\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    loss_t.append(loss)\n",
        "    labels.append(l.data)\n",
        "print(time.time() - start)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NilV4syzgjSi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(loss_t,int(nb_train/100)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LINKeGNjgjSs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(corrects,int(nb_train/100)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tIuGba5sgjSx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(np.cumsum(labels))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O2A1jCLJgjS5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "corrects_test =[]\n",
        "labels_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=True)\n",
        "    y = RNN(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    corrects_test.append(preds.item() == l.data.item())\n",
        "    labels_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z41nP-UhgjS-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(corrects_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lzOpkaLigjTD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "corrects_test =[]\n",
        "labels_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len, hard_false = True)\n",
        "    y = RNN(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    corrects_test.append(preds.item() == l.data.item())\n",
        "    labels_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Geelj-5XgjTM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(corrects_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "crSKTEJHgjTU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsh_test =[]\n",
        "labelsh_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input_hard()\n",
        "    y = RNN(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsh_test.append(preds.item() == l.data.item())\n",
        "    labelsh_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "br6HRXOhgjTc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsh_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I5q30nkqgjTn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsh_test =[]\n",
        "labelsh_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input_hard(true_parent=True)\n",
        "    y = RNN(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsh_test.append(preds.item() == l.data.item())\n",
        "    labelsh_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L6ZM_4c2gjTw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsh_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wvaIEg6igjT3",
        "colab_type": "text"
      },
      "source": [
        "# RNN with Gating "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kOiH5WsLgjT5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class RecNetGating(nn.Module):\n",
        "    def __init__(self, dim_input=10, dim_recurrent=50, dim_output=2):\n",
        "        super(RecNetGating, self).__init__()\n",
        "        self.fc_x2h = nn.Linear(dim_input, dim_recurrent)\n",
        "        self.fc_h2h = nn.Linear(dim_recurrent, dim_recurrent, bias = False)\n",
        "        self.fc_x2z = nn.Linear(dim_input, dim_recurrent)\n",
        "        self.fc_h2z = nn.Linear(dim_recurrent,dim_recurrent, bias = False)\n",
        "        self.fc_h2y = nn.Linear(dim_recurrent, dim_output)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        h = x.new_zeros(1, self.fc_h2y.weight.size(1))\n",
        "        for t in range(x.size(0)):\n",
        "            z = torch.sigmoid(self.fc_x2z(x[t,:])+self.fc_h2z(h))\n",
        "            hb = torch.relu(self.fc_x2h(x[t,:]) + self.fc_h2h(h))\n",
        "            h = z * h + (1-z) * hb   \n",
        "        return self.fc_h2y(h)    \n",
        "    \n",
        "RNNG = gpu(RecNetGating(dim_input = nb_symbol))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NZuKhEg-gjUB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "optimizerG = torch.optim.Adam(RNNG.parameters(),lr=1e-3)\n",
        "loss_tG = []\n",
        "correctsG =[]\n",
        "labelsG = []\n",
        "start = time.time()\n",
        "for k in range(nb_train):\n",
        "    x,l = generator.generate_input(hard_false = False)\n",
        "    y = RNNG(x)\n",
        "    loss = cross_entropy(y,l)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsG.append(preds.item() == l.data.item())\n",
        "    optimizerG.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizerG.step()\n",
        "    loss_tG.append(loss)\n",
        "    labelsG.append(l.item())\n",
        "print(time.time() - start)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6BbKxWkTgjUI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(loss_tG,int(nb_train/50)))\n",
        "plt.plot(running_mean(loss_t,int(nb_train/50)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Vjp8ybXqgjUO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(correctsG,int(nb_train/50)))\n",
        "plt.plot(running_mean(corrects,int(nb_train/50)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "opLeQwWogjUU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsG_test =[]\n",
        "labelsG_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=True)\n",
        "    y = RNNG(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsG_test.append(preds.item() == l.data.item())\n",
        "    labelsG_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FKF7r8dSgjUc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsG_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IwQg9tXRgjUq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsG_test =[]\n",
        "labelsG_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len, hard_false = True)\n",
        "    y = RNNG(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsG_test.append(preds.item() == l.data.item())\n",
        "    labelsG_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QdWoh4UTgjUv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsG_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OkkC1_jMgjU2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctshG_test =[]\n",
        "labelshG_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input_hard()\n",
        "    y = RNNG(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctshG_test.append(preds.item() == l.data.item())\n",
        "    labelshG_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kUMHjsXrgjVB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctshG_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9GO-LItVgjVN",
        "colab_type": "text"
      },
      "source": [
        "# LSTM"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4gTATkvrgjVR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LSTMNet(nn.Module):\n",
        "    def __init__(self, dim_input=10, dim_recurrent=50, num_layers=4, dim_output=2):\n",
        "        super(LSTMNet, self).__init__()\n",
        "        self.lstm = nn.LSTM(input_size = dim_input, \n",
        "                           hidden_size = dim_recurrent,\n",
        "                           num_layers = num_layers)\n",
        "        self.fc_o2y = nn.Linear(dim_recurrent,dim_output)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        x = x.unsqueeze(1)\n",
        "        output, _ = self.lstm(x)\n",
        "        output = output.squeeze(1)\n",
        "        output = output.narrow(0, output.size(0)-1,1)\n",
        "        return self.fc_o2y(F.relu(output))\n",
        "    \n",
        "lstm = gpu(LSTMNet(dim_input = nb_symbol))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WkjccxdsgjVX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "x, l = generator.generate_input()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sgdPdLupgjVf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "lstm(x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Rp2KI_eSgjVk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "optimizerL = torch.optim.Adam(lstm.parameters(),lr=1e-3)\n",
        "loss_tL = []\n",
        "correctsL =[]\n",
        "labelsL = []\n",
        "start = time.time()\n",
        "for k in range(nb_train):\n",
        "    x,l = generator.generate_input(hard_false = False)\n",
        "    y = lstm(x)\n",
        "    loss = cross_entropy(y,l)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsL.append(preds.item() == l.data.item())\n",
        "    optimizerL.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizerL.step()\n",
        "    loss_tL.append(loss)\n",
        "    labelsL.append(l.item())\n",
        "print(time.time() - start)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LwLenhrCgjVo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(loss_tL,int(nb_train/50)))\n",
        "plt.plot(running_mean(loss_tG,int(nb_train/50)))\n",
        "plt.plot(running_mean(loss_t,int(nb_train/50)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ng_wRELigjVt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.plot(running_mean(correctsL,int(nb_train/50)))\n",
        "plt.plot(running_mean(correctsG,int(nb_train/50)))\n",
        "plt.plot(running_mean(corrects,int(nb_train/50)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oFfEMA89gjVw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsL_test =[]\n",
        "labelsL_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=True)\n",
        "    y = lstm(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsL_test.append(preds.item() == l.data.item())\n",
        "    labelsL_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mn4MCPRDgjV2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsL_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CIvIS-gEgjV-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctsL_test =[]\n",
        "labelsL_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input(len_r=seq_max_len,true_parent=False,hard_false = True)\n",
        "    y = lstm(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctsL_test.append(preds.item() == l.data.item())\n",
        "    labelsL_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b06Gp7lrgjWC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctsL_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rniUHL4mgjWE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "nb_test = 1000\n",
        "correctshL_test =[]\n",
        "labelshL_test = []\n",
        "for k in range(nb_test):\n",
        "    x,l = generator.generate_input_hard()\n",
        "    y = lstm(x)\n",
        "    _,preds = torch.max(y.data,1)\n",
        "    correctshL_test.append(preds.item() == l.data.item())\n",
        "    labelshL_test.append(l.data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Zh43kUcGgjWI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.sum(correctshL_test)/nb_test"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mWUSVdiggjWO",
        "colab_type": "text"
      },
      "source": [
        "# GRU\n",
        "\n",
        "Implement your RNN with a [GRU](https://pytorch.org/docs/stable/nn.html#gru)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JSH65GakgjWV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OvHhG7KfgjWe",
        "colab_type": "text"
      },
      "source": [
        "# Explore!\n",
        "\n",
        "What are good negative examples?\n",
        "\n",
        "How to be sure that your network 'generalizes'?"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lCVb66-xgjWj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}